"""
Harmonic analysis over a longer (3-year) time window; the same analysis can be repeated with a 1-year window.
This script runs on NASA HECC, where 254 jobs are submitted - one per altimeter track.

Author: Badarvada Yadidya
Created on: August 2023

"""
# Import necessary libraries
import numpy as np
import xarray as xr
import dask.array as da
import dask
from dask import delayed
import pandas as pd
import sys
import cmaps as cmp
import os
import datetime
import matplotlib.pyplot as plt
import utide as ut

# Open the two zarr files containing the altimetry data.
# NOTE(review): ds3 is opened and given a time axis, but nothing later in this
# script reads it.
ds = xr.open_zarr('/nobackup/ybadarva/diurnal/altim_nt_d.zarr')
ds3 = xr.open_zarr('/nobackup/ybadarva/diurnal/altim_nt.zarr')

# Hourly time axis starting 2017-01-01 10:00, shared by both datasets and by the
# harmonic fits. N_HOURS must equal the length of the 'time' dimension of the
# zarr stores (26270 h, i.e. the ~3-year window).
# For the 1-year analysis, keep the same start time but select only one year of
# data instead of three.
N_HOURS = 26270
time = pd.date_range("2017-01-01 10:00", periods=N_HOURS, freq='h')
ds['time'] = time
ds3['time'] = time

# Open nlyss_info.nc, which contains information about the altimetry tracks;
# only its 'lat' variable is used in this script.
ds2 = xr.open_dataset('/nfs/turbo/lsa-arbic/arinne/finalized/Step02/outputs/nlyss_info.nc')

# Latitude per (np, nt) point: interpolate interior gaps along both dims, then
# forward/backward fill whatever interpolation could not reach.
lat = ds2.lat.T.interpolate_na(dim='np').interpolate_na(dim='nt')
lat = lat.ffill('nt').bfill('nt').ffill('np').bfill('np')
# Replace latitudes within [-1, 1] (and any NaN left over) with 1.
lat = lat.where((lat < -1) | (lat > 1), 1)
# Doing the above to avoid this error (utide divides by sin(lat)):
# /home/yadidya/.conda/envs/yadi/lib/python3.10/site-packages/utide/harmonics.py:129: RuntimeWarning: divide by zero encountered in scalar divide
#  rr[j] *= 0.36309 * (1.0 - 5.0 * slat**2) / slat
#/home/yadidya/.conda/envs/yadi/lib/python3.10/site-packages/numpy/core/fromnumeric.py:86: RuntimeWarning: invalid value encountered in reduce
#  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)

def detrend_dim(da, dim, deg=1):
    """Remove a least-squares polynomial trend of degree ``deg`` along ``dim``.

    The polynomial is fitted with ``DataArray.polyfit`` against the coordinate
    values of ``dim`` (not integer positions), evaluated with ``xr.polyval``,
    and subtracted from ``da``.

    Note: inside this function the parameter ``da`` shadows the module-level
    ``dask.array`` import.
    """
    coeffs = da.polyfit(dim=dim, deg=deg).polyfit_coefficients
    trend = xr.polyval(da[dim], coeffs)
    return da - trend

#=================================================================================================
# Track to process: positional index 'n' along dim 'nt', supplied as the first
# command-line argument (one job is submitted per track).
track_arg = sys.argv[1]
n = int(track_arg)

#=================================================================================================
# Along-track altimeter dataset; its 'time' and 'sla' variables are read later.
altim_path = '/nfs/turbo/lsa-arbic/arinne/finalized/Step03/outputs/nlyss_altim.nc'
altim = xr.open_dataset(altim_path)

# Define a function to convert MATLAB date numbers to Python datetime objects
def datetime_conv(matlab_dn):
    if np.isnan(matlab_dn):
        return pd.NaT
    else:
        python_datetime = datetime.datetime.fromordinal(int(matlab_dn)) + datetime.timedelta(days=matlab_dn%1) - datetime.timedelta(days = 366)
        return python_datetime

# Extract the time and sla variables from the altim dataset.
# .T reverses the stored dim order; below both are indexed as [nc, np, nt]
# (matching the dims/coords attached to the reconstruction outputs).
alti_time = altim.time.T
sla = altim.sla.T
#=================================================================================================
# Series for track n from ds. Each keeps a length-1 'nt' axis (axis 2) so the
# per-track output files can be combined across tracks afterwards.
tssh = ds.tssh[:,:,n].expand_dims(dim='nt', axis=2)
sssh = ds.sssh[:,:,n].expand_dims(dim='nt', axis=2)
# NOTE: the series called "fssh" here is ds.tssh - ds.fssh, not ds.fssh itself.
fssh = (ds.tssh[:,:,n]-ds.fssh[:,:,n]).expand_dims(dim='nt', axis=2)

# Constituents fitted at every point of the track.
constit = ['M2','S2','N2','K1','O1','H1','H2']

# Constituent groups used to reconstruct SSH at the altimeter sample times.
# Each group's label ('nw' coordinate) is its constituent names concatenated,
# e.g. ['M2','S2'] -> 'M2S2', so labels and groups cannot drift apart.
constit_sets = [
    ['M2','S2'], ['M2','S2','N2'], ['K1','O1'], ['H1','H2'],
    ['M2','S2','K1','O1','H1','H2'],
]
set_const = [''.join(s) for s in constit_sets]

# Factor applied to the series before fitting.
# NOTE(review): presumably divides by g = 9.8 m/s^2 and converts to cm;
# confirm the units of the zarr variables.
ssh_scale = 100/9.8

# Output buffers, pre-filled with NaN so skipped points stay missing.
# Reconstructions: (constituent group, nc, np, 1), sized from sla so they always
# match the coordinates attached further down.
sh_shape = (len(constit_sets), sla.shape[0], sla.shape[1], 1)
ste_sh = np.full(sh_shape, np.nan)
fil_sh = np.full(sh_shape, np.nan)

# Amplitudes and phases: (constituent, np, nt=1).
ste_am = np.full((len(constit), sssh.shape[1], sssh.shape[2]), np.nan)
fil_am = np.full((len(constit), fssh.shape[1], fssh.shape[2]), np.nan)
ste_ph = np.full_like(ste_am, np.nan)
fil_ph = np.full_like(fil_am, np.nan)


def _fit_point(series, lat_value):
    """Scale, detrend and harmonically fit one hourly series with utide.

    Returns the utide coefficient object and the positions of each entry of
    ``constit`` within it (looked up by name rather than assuming utide keeps
    the requested order).
    """
    detrended = detrend_dim(series * ssh_scale, dim='time').values
    coef = ut.solve(time, detrended, lat=lat_value, constit=constit, verbose=False)
    names = list(coef.name)
    return coef, [names.index(c) for c in constit]


for m in range(ds.sizes['np']):
    # Skip points where the total series is entirely missing.
    if np.all(np.isnan(tssh[:,m,0].values)):
        continue
    lat_m = lat[m,n].values.item()
    coef2, indx2 = _fit_point(sssh[:,m,0], lat_m)
    coef3, indx3 = _fit_point(fssh[:,m,0], lat_m)

    ste_am[:,m,0] = coef2.A[indx2]
    fil_am[:,m,0] = coef3.A[indx3]

    ste_ph[:,m,0] = coef2.g[indx2]
    fil_ph[:,m,0] = coef3.g[indx3]

    # Altimeter sample times for this point (MATLAB datenums -> datetimes);
    # samples with a missing time or SLA are blanked in the reconstructions.
    t1 = xr.DataArray([datetime_conv(dn) for dn in alti_time[:,m,n].values], dims='time')
    mask = np.isnan(alti_time[:,m,n].values) | np.isnan(sla[:,m,n].values)

    for i, constit_set in enumerate(constit_sets):
        ste_h = ut.reconstruct(t1, coef2, constit=constit_set, verbose=False)['h']
        fil_h = ut.reconstruct(t1, coef3, constit=constit_set, verbose=False)['h']
        ste_sh[i,:,m,0] = np.where(mask, np.nan, ste_h)
        fil_sh[i,:,m,0] = np.where(mask, np.nan, fil_h)


def _sh_to_da(values, name):
    """Wrap a (group, nc, np, 1) reconstruction buffer for track n."""
    return xr.DataArray(values, dims=['nw', sla.dims[0], sla.dims[1], sla.dims[2]],
                        coords=dict(nw=np.array(set_const), nc=sla.coords['nc'],
                                    np=sla.coords['np'], nt=np.array([n])),
                        name=name)


def _constit_to_da(values, like, name):
    """Wrap a (constituent, np, nt) buffer using the dims/coords of ``like``."""
    return xr.DataArray(values, dims=['constit', like.dims[1], like.dims[2]],
                        coords=dict(constit=np.array(constit), np=like.coords['np'],
                                    nt=like.coords['nt']),
                        name=name)


# Save the data variables in to xarray dataarrays
ste_sh = _sh_to_da(ste_sh, 'sssh')
fil_sh = _sh_to_da(fil_sh, 'fssh')
ste_am = _constit_to_da(ste_am, sssh, 'stam')
fil_am = _constit_to_da(fil_am, fssh, 'flam')
ste_ph = _constit_to_da(ste_ph, sssh, 'stph')
fil_ph = _constit_to_da(fil_ph, fssh, 'flph')

# Set the zero values where land is present to NaN. Each array is masked on
# its own values (fil_sh was previously masked with an all-NaN buffer, which
# left its zeros in place).
ste_sh = ste_sh.where(ste_sh != 0, np.nan)
fil_sh = fil_sh.where(fil_sh != 0, np.nan)
ste_am = ste_am.where(ste_am != 0, np.nan)
fil_am = fil_am.where(fil_am != 0, np.nan)
ste_ph = ste_ph.where(ste_ph != 0, np.nan)
fil_ph = fil_ph.where(fil_ph != 0, np.nan)

# Merge all the dataarrays to xarray dataset
ds_out = xr.merge([ste_sh,fil_sh,
                   ste_am,fil_am,
                   ste_ph,fil_ph])

ds_out.to_netcdf(f'/nfs/turbo/lsa-arbic/yadidya/hycom/altim_day/out_nt{n}.nc')

# This produces 254 netCDF files, one per track. They are combined into a single zarr store, from which subset datasets are created for simpler analysis and plotting.